import argparse


def eval_sample_parse_args():
    """Build the argument parser for the sample/animation evaluation script.

    The caller is responsible for calling ``parse_args`` on the returned
    parser.

    Returns:
        argparse.ArgumentParser with ``--model_path``, ``--n_tries``,
        ``--n_nodes`` and ``--batch_size_gen`` registered.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_path', type=str,
                        default="outputs/edm_1",
                        help='Specify model path')
    parser.add_argument(
        '--n_tries', type=int, default=10,
        help='N tries to find stable molecule for gif animation')
    parser.add_argument('--n_nodes', type=int, default=19,
                        help='number of atoms in molecule for gif animation')
    # Help text previously duplicated '--model_path' ('Specify model path').
    parser.add_argument('--batch_size_gen', type=int, default=100,
                        help='Batch size used when generating samples')
    return parser


def eval_parse_args():
    """Build the argument parser for the evaluation script.

    The caller is responsible for calling ``parse_args`` on the returned
    parser.

    Returns:
        argparse.ArgumentParser with ``--model_path``, ``--n_samples``,
        ``--batch_size_gen``, ``--save_to_xyz`` and ``--target_domain``
        registered.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_path', type=str, default="outputs/edm_1",
                        help='Specify model path')
    # The two help texts below previously duplicated '--model_path'.
    parser.add_argument('--n_samples', type=int, default=100,
                        help='Number of samples to generate')
    parser.add_argument('--batch_size_gen', type=int, default=100,
                        help='Batch size used when generating samples')
    # NOTE(review): ``type=eval`` executes the raw command-line string as a
    # Python expression. It accepts 'True'/'False', but
    # ``ast.literal_eval`` would be a safer converter if the CLI can ever be
    # driven by untrusted input.
    parser.add_argument('--save_to_xyz', type=eval, default=False,
                        help='Should save samples to xyz files.')

    parser.add_argument('--target_domain', type=str, default='train',
                        help='From where to condition')
    return parser


def _str2bool(value):
    """Parse a command-line string into a ``bool``.

    ``type=bool`` is an argparse pitfall: ``bool("False")`` is ``True``
    because every non-empty string is truthy, so flags declared that way
    could never be switched off from the command line. This converter
    accepts the usual spellings instead.

    Args:
        value: Raw string from the command line. A ``bool`` is returned
            unchanged.

    Returns:
        ``True`` for 'true'/'t'/'yes'/'y'/'1' and ``False`` for
        'false'/'f'/'no'/'n'/'0'. Matching is case-insensitive and ignores
        surrounding whitespace.

    Raises:
        argparse.ArgumentTypeError: for any other string. argparse turns
            this into a normal usage error.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(
        'Boolean value expected, got %r' % (value,))


def parse_args(argv=None):
    """Build the E3Diffusion training argument parser and parse arguments.

    Args:
        argv: Optional list of argument strings. ``None`` (the default)
            parses ``sys.argv[1:]``, exactly as before. Passing an explicit
            list lets tests and notebooks build a config without touching
            ``sys.argv``.

    Returns:
        argparse.Namespace with one attribute per option below. Dashed
        option names become underscored attributes, e.g. ``--max-nodes``
        becomes ``max_nodes`` and ``--no-cuda`` becomes ``no_cuda``.
    """
    parser = argparse.ArgumentParser(description='E3Diffusion')
    parser.add_argument('--exp_name', type=str, default='debug_10')

    # VQ
    parser.add_argument("--num_e", default=4000, type=int)
    parser.add_argument("--commitment_weight", default=0.1, type=float)
    # VQ parameters end

    # DAMG args
    # Boolean flags use _str2bool: with type=bool, "False" would parse as True.
    parser.add_argument("--is_da_mg", type=_str2bool, default=False, help='Domain adaptive generation')
    parser.add_argument("--self_condition_nf", default=4, type=int)
    parser.add_argument("--score_regularization_para", default=0.5, type=float)
    parser.add_argument("--property_coef", default=0.01, type=float)
    parser.add_argument('--source_dataset', type=str, default='COMPAS',
                        help='qm9 | geom | COMPAS')
    parser.add_argument('--target_dataset', type=str, default='geom',
                        help='qm9 | geom')
    parser.add_argument("--target_data_size", default=10, type=int)
    parser.add_argument('--generators_path', type=str, default='outputs/exp_cond_alpha_pretrained')
    parser.add_argument('--da_epoch', type=int, default=3000,
                        help='number of epochs the model for adaptive generation')

    parser.add_argument("--mask_ratio", default=0.75, type=float, help='Mask ratio for condition')

    # Latent Diffusion args
    parser.add_argument('--train_diffusion', action='store_true',
                        help='Train second stage LatentDiffusionModel model')
    parser.add_argument('--ae_path', type=str, default=None,
                        help='Specify first stage model path')
    parser.add_argument('--trainable_ae', action='store_true',
                        help='Train first stage AutoEncoder model')

    # PAS dataset args
    # mol_data param
    parser.add_argument("--pas_data", type=_str2bool, default=False)
    parser.add_argument("--pas_dataset", default="cata", type=str)
    parser.add_argument("--rings_graph", type=_str2bool, default=True)
    # Was type=str with an int default: the value was an int by default but
    # a str when passed on the CLI.
    parser.add_argument("--max-nodes", default=11, type=int)
    parser.add_argument("--normalize", type=_str2bool, default=True)
    parser.add_argument("--sample-rate", type=float, default=1)

    # VAE args
    parser.add_argument('--latent_nf', type=int, default=4,
                        help='number of latent features')
    parser.add_argument('--kl_weight', type=float, default=0.01,
                        help='weight of KL term in ELBO')

    parser.add_argument('--model', type=str, default='egnn_dynamics',
                        help='our_dynamics | schnet | simple_dynamics | '
                             'kernel_dynamics | egnn_dynamics | gnn_dynamics')
    parser.add_argument('--probabilistic_model', type=str, default='diffusion',
                        help='diffusion')

    # Training complexity is O(1) (unaffected), but sampling complexity is O(steps).
    parser.add_argument('--diffusion_steps', type=int, default=500)
    parser.add_argument('--diffusion_noise_schedule', type=str, default='polynomial_2',
                        help='learned, cosine')
    parser.add_argument('--diffusion_noise_precision', type=float, default=1e-5)
    parser.add_argument('--diffusion_loss_type', type=str, default='l2',
                        help='vlb, l2')

    parser.add_argument('--n_epochs', type=int, default=200)
    parser.add_argument('--batch_size', type=int, default=128)
    parser.add_argument('--lr', type=float, default=2e-4)
    # NOTE(review): the ``type=eval`` options in this parser evaluate the raw
    # command-line string as a Python expression. ``ast.literal_eval`` would
    # be a safer drop-in if the CLI can ever be driven by untrusted input.
    parser.add_argument('--brute_force', type=eval, default=False,
                        help='True | False')
    parser.add_argument('--actnorm', type=eval, default=True,
                        help='True | False')
    parser.add_argument('--break_train_epoch', type=eval, default=False,
                        help='True | False')
    parser.add_argument('--dp', type=eval, default=True,
                        help='True | False')
    parser.add_argument('--condition_time', type=eval, default=True,
                        help='True | False')
    parser.add_argument('--clip_grad', type=eval, default=True,
                        help='True | False')
    parser.add_argument('--trace', type=str, default='hutch',
                        help='hutch | exact')
    parser.add_argument('--analyze_during_train', type=_str2bool, default=False,
                        help='True | False')
    # EGNN args -->
    parser.add_argument('--n_layers', type=int, default=6,
                        help='number of layers')
    parser.add_argument('--inv_sublayers', type=int, default=1,
                        help='number of sublayers per EGNN layer')
    parser.add_argument('--nf', type=int, default=128,
                        help='number of hidden features')
    parser.add_argument('--tanh', type=eval, default=True,
                        help='use tanh in the coord_mlp')
    parser.add_argument('--attention', type=eval, default=True,
                        help='use attention in the EGNN')
    parser.add_argument('--norm_constant', type=float, default=1,
                        help='diff/(|diff| + norm_constant)')
    parser.add_argument('--sin_embedding', type=eval, default=False,
                        help='whether using or not the sin embedding')
    # <-- EGNN args

    parser.add_argument('--ode_regularization', type=float, default=1e-3)
    parser.add_argument('--dataset', type=str, default='qm9',
                        help='qm9 | qm9_second_half (train only on the last 50K samples of the training dataset) | qm9_few_shot')
    parser.add_argument('--datadir', type=str, default='qm9/temp',
                        help='qm9 directory')
    parser.add_argument('--filter_n_atoms', type=int, default=None,
                        help='When set to an integer value, QM9 will only contain molecules of that amount of atoms')
    parser.add_argument('--dequantization', type=str, default='argmax_variational',
                        help='uniform | variational | argmax_variational | deterministic')
    parser.add_argument('--n_report_steps', type=int, default=1)
    parser.add_argument('--wandb_usr', type=str)
    parser.add_argument('--no_wandb', action='store_true', help='Disable wandb')
    parser.add_argument('--online', type=_str2bool, default=True, help='True = wandb online -- False = wandb offline')
    parser.add_argument('--shuffle_self_condition', type=_str2bool, default=False, help='Shuffle self condition features or not')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--save_model', type=eval, default=True,
                        help='save model')
    parser.add_argument('--generate_epochs', type=int, default=1,
                        help='interval (in epochs) between sample generation')
    parser.add_argument('--num_workers', type=int, default=0, help='Number of worker for the dataloader')
    parser.add_argument('--test_epochs', type=int, default=10)
    parser.add_argument('--data_augmentation', type=eval, default=False, help='apply data augmentation during training')
    parser.add_argument("--conditioning", nargs='+', default=[],
                        help='arguments : homo | lumo | alpha | gap | mu | Cv')
    parser.add_argument('--resume', type=str, default=None,
                        help='')
    parser.add_argument('--start_epoch', type=int, default=0,
                        help='')
    parser.add_argument('--ema_decay', type=float, default=0.999,
                        help='Amount of EMA decay, 0 means off. A reasonable value'
                             ' is 0.999.')
    parser.add_argument('--augment_noise', type=float, default=0)
    parser.add_argument('--n_stability_samples', type=int, default=500,
                        help='Number of samples to compute the stability')
    parser.add_argument('--normalize_factors', type=eval, default=[1, 4, 1],
                        help='normalize factors for [x, categorical, integer]')
    parser.add_argument('--remove_h', action='store_true')
    parser.add_argument('--include_charges', type=eval, default=True,
                        help='include atom charge or not')
    # Default was the float 1e8 despite type=int; keep the same magnitude as an int.
    parser.add_argument('--visualize_every_batch', type=int, default=int(1e8),
                        help="Can be used to visualize multiple times per epoch")
    parser.add_argument('--normalization_factor', type=float, default=1,
                        help="Normalize the sum aggregation of EGNN")
    parser.add_argument('--aggregation_method', type=str, default='sum',
                        help='"sum" or "mean"')
    parser.add_argument('--filter_molecule_size', type=int, default=None,
                        help="Only use molecules below this size.")
    parser.add_argument('--sequential', action='store_true',
                        help='Organize mol_data by size to reduce average memory usage.')
    args = parser.parse_args(argv)
    return args
